﻿using System.Runtime.CompilerServices;
using System.Security.Cryptography;
using System.Text;
using Microsoft.AspNetCore.Cryptography.KeyDerivation;
using Microsoft.AspNetCore.Identity;
using SymmetricAlgorithm = System.Security.Cryptography.SymmetricAlgorithm;

namespace Devonline.Identity;

/// <summary>
/// Deterministic AES-256-CBC implementation of <see cref="ILookupProtector"/>.
/// </summary>
/// <remarks>
/// The protected value is the Base64 encoding of
/// <c>[4-byte algorithm id][32-byte HMAC-SHA256][16-byte IV][cipher text]</c>.
/// The IV is derived from the plain text and the master key, so protecting the same
/// value with the same key always yields the same output. The HMAC covers the IV and
/// the cipher text and is verified before any decryption is attempted.
/// </remarks>
public class AesProtector : ILookupProtector
{
    /// <summary>AES key size in bits.</summary>
    private const int DefaultKeySize = 256;

    /// <summary>Length in bytes of the derived encryption and signing keys.</summary>
    private const int KeyLength = DefaultKeySize / 8;

    /// <summary>Length in bytes of the IV; the AES block size is fixed at 128 bits.</summary>
    private const int InitializationVectorLength = 128 / 8;

    /// <summary>Length in bytes of the algorithm identifier prefix.</summary>
    private const int AlgorithmIdentifierLength = sizeof(int);

    /// <summary>Length in bytes of an HMAC-SHA256 signature.</summary>
    private const int SignatureLength = 256 / 8;

    /// <summary>The only algorithm this protector writes and accepts.</summary>
    private static readonly Core.SymmetricAlgorithm DefaultAlgorithm = Core.SymmetricAlgorithm.AES_256_CBC;

    private readonly ILookupProtectorKeyRing _keyRing;
    private readonly DataProtectionOptions _options;

    /// <summary>
    /// Creates a protector that reads passwords and the PBKDF2 iteration count from
    /// <paramref name="options"/> and master keys from <paramref name="keyRing"/>.
    /// </summary>
    /// <param name="options">Key derivation settings.</param>
    /// <param name="keyRing">Source of the Base64 encoded master keys.</param>
    /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
    public AesProtector(DataProtectionOptions options, ILookupProtectorKeyRing keyRing)
    {
        _options = options ?? throw new ArgumentNullException(nameof(options));
        _keyRing = keyRing ?? throw new ArgumentNullException(nameof(keyRing));
    }

    /// <summary>
    /// Encrypts and signs <paramref name="data"/> with the master key identified by
    /// <paramref name="keyId"/>.
    /// </summary>
    /// <param name="keyId">Identifier of the master key in the key ring.</param>
    /// <param name="data">Plain text to protect.</param>
    /// <returns>The Base64 encoded protected payload.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is <c>null</c>.</exception>
    public string Protect(string keyId, string data)
    {
        if (data is null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        byte[] masterKey = MasterKey(keyId);

        // Encryption works on bytes, not strings.
        byte[] plainText = Encoding.UTF8.GetBytes(data);
        try
        {
            // The output must be deterministic, so the IV is derived from the plain text.
            byte[] initializationVector = DerivedInitializationVector(data, masterKey);
            byte[] cipherText = Transform(masterKey, initializationVector, plainText, encrypt: true);
            byte[] cipherTextAndIV = CombineByteArrays(initializationVector, cipherText);

            // Sign IV + cipher text so tampering can be detected before decrypting.
            byte[] signature = SignData(cipherTextAndIV, masterKey);
            byte[] signedData = CombineByteArrays(signature, cipherTextAndIV);

            // Prefix the algorithm identifier. BitConverter is kept so the layout stays
            // byte-compatible with payloads that were already written.
            byte[] algorithmIdentifier = BitConverter.GetBytes((int)DefaultAlgorithm);
            return Convert.ToBase64String(CombineByteArrays(algorithmIdentifier, signedData));
        }
        finally
        {
            // Scrub sensitive buffers even when encryption fails.
            Array.Clear(plainText, 0, plainText.Length);
            Array.Clear(masterKey, 0, masterKey.Length);
        }
    }

    /// <summary>
    /// Verifies and decrypts a payload produced by <see cref="Protect"/>.
    /// </summary>
    /// <param name="keyId">Identifier of the master key in the key ring.</param>
    /// <param name="data">Base64 encoded protected payload.</param>
    /// <returns>The original plain text.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is <c>null</c>.</exception>
    /// <exception cref="CryptographicException">
    /// The payload is too short, names a different algorithm, or fails signature verification.
    /// </exception>
    public string Unprotect(string keyId, string data)
    {
        if (data is null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        byte[] masterKey = MasterKey(keyId);
        try
        {
            byte[] payload = Convert.FromBase64String(data);

            // Reject truncated input up front so it is reported as a cryptographic failure
            // rather than as an array copy or allocation error further down.
            if (payload.Length < AlgorithmIdentifierLength + SignatureLength + InitializationVectorLength)
            {
                throw new CryptographicException("Invalid Payload.");
            }

            var algorithmIdentifier = (Core.SymmetricAlgorithm)BitConverter.ToInt32(payload, 0);
            if (algorithmIdentifier != DefaultAlgorithm)
            {
                throw new CryptographicException("Invalid Symmetric Algorithm.");
            }

            int offset = AlgorithmIdentifierLength;
            byte[] signature = Slice(payload, offset, SignatureLength);
            offset += SignatureLength;
            byte[] cipherTextAndIV = Slice(payload, offset, payload.Length - offset);

            // Check the signature before anything else to detect tampering and avoid oracles.
            byte[] computedSignature = SignData(cipherTextAndIV, masterKey);
            if (!CryptographicOperations.FixedTimeEquals(computedSignature, signature))
            {
                throw new CryptographicException("Invalid Signature.");
            }

            // The IV is stored in front of the cipher text.
            byte[] initializationVector = Slice(cipherTextAndIV, 0, InitializationVectorLength);
            byte[] cipherText = Slice(
                cipherTextAndIV,
                InitializationVectorLength,
                cipherTextAndIV.Length - InitializationVectorLength);

            byte[] plainText = Transform(masterKey, initializationVector, cipherText, encrypt: false);
            try
            {
                return Encoding.UTF8.GetString(plainText);
            }
            finally
            {
                Array.Clear(plainText, 0, plainText.Length);
            }
        }
        finally
        {
            Array.Clear(masterKey, 0, masterKey.Length);
        }
    }

    /// <summary>
    /// Gets the master key stored in the key ring under <paramref name="keyId"/>.
    /// </summary>
    /// <param name="keyId">Identifier of the master key.</param>
    /// <returns>A newly allocated array holding the Base64 decoded key.</returns>
    public byte[] MasterKey(string keyId) => Convert.FromBase64String(_keyRing[keyId]);

    /// <summary>
    /// Runs <paramref name="input"/> through AES-256-CBC with a key derived from
    /// <paramref name="masterKey"/>.
    /// </summary>
    /// <param name="masterKey">Master key used as the PBKDF2 salt for the encryption key.</param>
    /// <param name="initializationVector">IV for the cipher.</param>
    /// <param name="input">Plain text when encrypting, cipher text when decrypting.</param>
    /// <param name="encrypt"><c>true</c> to encrypt, <c>false</c> to decrypt.</param>
    /// <returns>The transformed bytes.</returns>
    private byte[] Transform(byte[] masterKey, byte[] initializationVector, byte[] input, bool encrypt)
    {
        byte[] encryptionKey = DerivedEncryptionKey(masterKey);
        try
        {
            using var aes = Aes.Create();
            aes.KeySize = DefaultKeySize;

            // These are the Aes.Create() defaults; stated explicitly because the payload
            // format depends on them.
            aes.Mode = CipherMode.CBC;
            aes.Padding = PaddingMode.PKCS7;
            aes.Key = encryptionKey;
            aes.IV = initializationVector;

            // CryptoStream does not dispose the transform, so it is owned here.
            using var transform = encrypt ? aes.CreateEncryptor() : aes.CreateDecryptor();
            using var buffer = new MemoryStream();
            using var cryptoStream = new CryptoStream(buffer, transform, CryptoStreamMode.Write);
            cryptoStream.Write(input, 0, input.Length);
            cryptoStream.FlushFinalBlock();
            return buffer.ToArray();
        }
        finally
        {
            Array.Clear(encryptionKey, 0, encryptionKey.Length);
        }
    }

    /// <summary>
    /// Computes the HMAC-SHA256 signature of <paramref name="cipherText"/>.
    /// </summary>
    /// <param name="cipherText">Bytes to sign (IV followed by cipher text).</param>
    /// <param name="masterKey">Master key used as the PBKDF2 salt for the signing key.</param>
    /// <returns>The 32-byte signature.</returns>
    private byte[] SignData(byte[] cipherText, byte[] masterKey)
    {
        byte[] signingKey = DerivedSigningKey(masterKey);
        try
        {
            using var hmac = new HMACSHA256(signingKey);
            return hmac.ComputeHash(cipherText);
        }
        finally
        {
            Array.Clear(signingKey, 0, signingKey.Length);
        }
    }

    /// <summary>
    /// Derives the signing key from the configured sign password, salted with the master key.
    /// </summary>
    /// <param name="key">Master key used as the PBKDF2 salt.</param>
    /// <returns>A key of <see cref="KeyLength"/> bytes.</returns>
    private byte[] DerivedSigningKey(byte[] key) =>
        KeyDerivation.Pbkdf2(_options.SignPassword, key, KeyDerivationPrf.HMACSHA256, _options.KeyDerivationIterationCount, KeyLength);

    /// <summary>
    /// Derives the encryption key from the configured key password, salted with the master key.
    /// </summary>
    /// <param name="key">Master key used as the PBKDF2 salt.</param>
    /// <returns>A key of <see cref="KeyLength"/> bytes.</returns>
    private byte[] DerivedEncryptionKey(byte[] key) =>
        KeyDerivation.Pbkdf2(_options.KeyPassword, key, KeyDerivationPrf.HMACSHA256, _options.KeyDerivationIterationCount, KeyLength);

    /// <summary>
    /// Derives a deterministic IV from the plain text, salted with the master key.
    /// </summary>
    /// <param name="plainText">Plain text used as the PBKDF2 password.</param>
    /// <param name="key">Master key used as the PBKDF2 salt.</param>
    /// <returns>An IV of <see cref="InitializationVectorLength"/> bytes.</returns>
    private byte[] DerivedInitializationVector(string plainText, byte[] key) =>
        KeyDerivation.Pbkdf2(plainText, key, KeyDerivationPrf.HMACSHA256, _options.KeyDerivationIterationCount, InitializationVectorLength);

    /// <summary>
    /// Concatenates two byte arrays.
    /// </summary>
    /// <param name="left">Bytes placed first.</param>
    /// <param name="right">Bytes placed after <paramref name="left"/>.</param>
    /// <returns>A new array containing <paramref name="left"/> followed by <paramref name="right"/>.</returns>
    private static byte[] CombineByteArrays(byte[] left, byte[] right)
    {
        byte[] output = new byte[left.Length + right.Length];
        Buffer.BlockCopy(left, 0, output, 0, left.Length);
        Buffer.BlockCopy(right, 0, output, left.Length, right.Length);
        return output;
    }

    /// <summary>
    /// Copies <paramref name="count"/> bytes of <paramref name="source"/> starting at
    /// <paramref name="offset"/> into a new array.
    /// </summary>
    /// <param name="source">Array to copy from.</param>
    /// <param name="offset">Index of the first byte to copy.</param>
    /// <param name="count">Number of bytes to copy.</param>
    /// <returns>A new array of <paramref name="count"/> bytes.</returns>
    private static byte[] Slice(byte[] source, int offset, int count)
    {
        byte[] output = new byte[count];
        Buffer.BlockCopy(source, offset, output, 0, count);
        return output;
    }
}